{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W_tvPdyfA-BL"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "0O_LFhwSBCjm"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9-3Pry4jh1-E"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/examples/blob/master/courses/udacity_intro_to_tensorflow_for_deep_learning/l06c01_tensorflow_hub_and_transfer_learning.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/examples/blob/master/courses/udacity_intro_to_tensorflow_for_deep_learning/l06c01_tensorflow_hub_and_transfer_learning.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NxjpzKTvg_dd"
      },
      "source": [
        "# TensorFlow Hub and Transfer Learning"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "crU-iluJIEzw"
      },
      "source": [
        "[TensorFlow Hub](http://tensorflow.org/hub) is an online repository of already trained TensorFlow models that you can use.\n",
        "These models can either be used as is, or they can be used for Transfer Learning.\n",
        "\n",
        "Transfer learning is a process where you take an existing trained model, and extend it to do additional work. This involves leaving the bulk of the model unchanged, while adding and retraining the final layers, in order to get a different set of possible outputs.\n",
        "\n",
        "In this Colab we will do both.\n",
        "\n",
        "Here, you can see all the models available in [TensorFlow Module Hub](https://tfhub.dev/).\n",
        "\n",
        "## Concepts that will be covered in this Colab\n",
        "\n",
        "1. Use a TensorFlow Hub model for prediction.\n",
        "2. Use a TensorFlow Hub model for the Dogs vs. Cats dataset.\n",
        "3. Do simple transfer learning with TensorFlow Hub.\n",
        "\n",
        "Before starting this Colab, you should reset the Colab environment by selecting `Runtime -> Reset all runtimes...` from the menu above."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7RVsYZLEpEWs"
      },
      "source": [
        "# Imports\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZUCEcRdhnyWn"
      },
      "source": [
        "Some normal imports we've seen before. The new one is `tensorflow_hub`, which this Colab will make heavy use of."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8z5hqr0hHtLv"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WZnAHGETHu7e"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import tensorflow_hub as hub\n",
        "import tensorflow_datasets as tfds\n",
        "\n",
        "from tensorflow.keras import layers"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FVM2fKGEHIJN"
      },
      "outputs": [],
      "source": [
        "import logging\n",
        "logger = tf.get_logger()\n",
        "logger.setLevel(logging.ERROR)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s4YuF5HvpM1W"
      },
      "source": [
        "# Part 1: Use a TensorFlow Hub MobileNet for prediction"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4Sh2sPc10V0b"
      },
      "source": [
        "In this part of the Colab, we'll take a trained model, load it into Keras, and try it out.\n",
        "\n",
        "The model that we'll use is MobileNet v2 (but any [TF2-compatible image classifier URL from tfhub.dev](https://tfhub.dev/s?q=tf2\u0026module-type=image-classification) would work)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xEY_Ow5loN6q"
      },
      "source": [
        "## Download the classifier\n",
        "\n",
        "Download the MobileNet model and create a Keras model from it.\n",
        "MobileNet expects images of 224 $\\times$ 224 pixels, in 3 color channels (RGB)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "y_6bGjoPtzau"
      },
      "outputs": [],
      "source": [
        "CLASSIFIER_URL = \"https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/2\"\n",
        "IMAGE_RES = 224\n",
        "\n",
        "model = tf.keras.Sequential([\n",
        "    hub.KerasLayer(CLASSIFIER_URL, input_shape=(IMAGE_RES, IMAGE_RES, 3))\n",
        "])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pwZXaoV0uXp2"
      },
      "source": [
        "## Run it on a single image"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TQItP1i55-di"
      },
      "source": [
        "MobileNet has been trained on the ImageNet dataset. ImageNet has 1000 different output classes, and one of them is military uniforms.\n",
        "Let's get an image containing a military uniform that is not part of ImageNet, and see if our model can predict that it is a military uniform."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "w5wDjXNjuXGD"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import PIL.Image as Image\n",
        "\n",
        "grace_hopper = tf.keras.utils.get_file('image.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg')\n",
        "grace_hopper = Image.open(grace_hopper).resize((IMAGE_RES, IMAGE_RES))\n",
        "grace_hopper"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BEmmBnGbLxPp"
      },
      "outputs": [],
      "source": [
        "grace_hopper = np.array(grace_hopper)/255.0\n",
        "grace_hopper.shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Ic8OEEo2b73"
      },
      "source": [
        "Remember, models always want a batch of images to process. So here, we add a batch dimension, and pass the image to the model for prediction."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EMquyn29v8q3"
      },
      "outputs": [],
      "source": [
        "result = model.predict(grace_hopper[np.newaxis, ...])\n",
        "result.shape"
      ]
    },
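    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a minimal sketch of what adding a batch dimension does, the snippet below uses an all-zeros array in place of the real image. The name `image` and its contents are assumptions made only for illustration.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Hypothetical stand-in for one 224 x 224 RGB image scaled to [0, 1].\n",
        "image = np.zeros((224, 224, 3), dtype=np.float32)\n",
        "\n",
        "# np.newaxis inserts a new leading axis of length 1: the batch dimension.\n",
        "batch = image[np.newaxis, ...]\n",
        "\n",
        "print(image.shape)  # (224, 224, 3)\n",
        "print(batch.shape)  # (1, 224, 224, 3)\n",
        "```\n",
        "\n",
        "A batch of one image therefore has four dimensions: batch, height, width, and channels."
      ]
    },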
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NKzjqENF6jDF"
      },
      "source": [
        "The result is a 1001-element vector of logits, one score per class (the 1000 ImageNet classes plus a background class). The higher the logit, the more likely the model considers that class for the image.\n",
        "\n",
        "So the top class ID can be found with argmax. But how can we know what class this actually is, and in particular whether that class ID in the ImageNet dataset denotes a military uniform or something else?"
      ]
    },
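    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Logits are unnormalized scores, not probabilities. If probabilities are wanted, one common option is to apply a softmax. The sketch below is illustrative only: the helper `softmax` and the four made-up logits are assumptions, not real MobileNet output.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def softmax(logits):\n",
        "  # Subtract the max for numerical stability; the result is unchanged.\n",
        "  shifted = logits - np.max(logits, axis=-1, keepdims=True)\n",
        "  exp = np.exp(shifted)\n",
        "  return exp / np.sum(exp, axis=-1, keepdims=True)\n",
        "\n",
        "# Hypothetical logits for a 4-class problem.\n",
        "example_logits = np.array([2.0, 1.0, 0.1, -1.0])\n",
        "probabilities = softmax(example_logits)\n",
        "\n",
        "# argmax picks the same class for logits and for probabilities.\n",
        "top_class = int(np.argmax(probabilities))\n",
        "# Indices of the 3 highest-scoring classes, best first.\n",
        "top_3 = np.argsort(example_logits)[::-1][:3]\n",
        "\n",
        "print(top_class)  # 0\n",
        "print(top_3)\n",
        "print(probabilities)\n",
        "```\n",
        "\n",
        "Because softmax preserves the ordering of the scores, taking argmax directly on the logits gives the same class ID."
      ]
    },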
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rgXb44vt6goJ"
      },
      "outputs": [],
      "source": [
        "predicted_class = np.argmax(result[0], axis=-1)\n",
        "predicted_class"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YrxLMajMoxkf"
      },
      "source": [
        "## Decode the predictions\n",
        "\n",
        "To see what our `predicted_class` is in the ImageNet dataset, download the ImageNet labels and fetch the row that the model predicted."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ij6SrDxcxzry"
      },
      "outputs": [],
      "source": [
        "labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')\n",
        "# Read the labels with a context manager so the file is closed afterwards.\n",
        "with open(labels_path) as f:\n",
        "  imagenet_labels = np.array(f.read().splitlines())\n",
        "\n",
        "plt.imshow(grace_hopper)\n",
        "plt.axis('off')\n",
        "predicted_class_name = imagenet_labels[predicted_class]\n",
        "_ = plt.title(\"Prediction: \" + predicted_class_name.title())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a6TNYYAM4u2-"
      },
      "source": [
        "Bingo. Our model correctly predicted military uniform!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "amfzqn1Oo7Om"
      },
      "source": [
        "# Part 2: Use a TensorFlow Hub model for the Dogs vs. Cats dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K-nIpVJ94xrw"
      },
      "source": [
        "Now we'll use the full MobileNet model and see how it can perform on the Dogs vs. Cats dataset."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z93vvAdGxDMD"
      },
      "source": [
        "## Dataset\n",
        "\n",
        "We can use TensorFlow Datasets to load the Dogs vs. Cats dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DrIUV3V0xDL_"
      },
      "outputs": [],
      "source": [
        "(train_examples, validation_examples), info = tfds.load(\n",
        "    'cats_vs_dogs',\n",
        "    with_info=True,\n",
        "    as_supervised=True,\n",
        "    split=['train[:80%]', 'train[80%:]'],\n",
        ")\n",
        "\n",
        "num_examples = info.splits['train'].num_examples\n",
        "num_classes = info.features['label'].num_classes"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UlFZ_hwjCLgS"
      },
      "source": [
        "The images in the Dogs vs. Cats dataset are not all the same size."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W4lDPkn2cpWZ"
      },
      "outputs": [],
      "source": [
        "for i, example_image in enumerate(train_examples.take(3)):\n",
        "  print(\"Image {} shape: {}\".format(i+1, example_image[0].shape))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mbgpD3E6gM2P"
      },
      "source": [
        "So we need to reformat all images to the resolution expected by MobileNet (224, 224).\n",
        "\n",
        "Adding `.repeat()` to the training pipeline and passing `steps_per_epoch` to `fit` is not required, and the code below omits both, but doing so saves ~15s per epoch, since the shuffle buffer only has to cold-start once."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "we_ftzQxNf7e"
      },
      "outputs": [],
      "source": [
        "def format_image(image, label):\n",
        "  image = tf.image.resize(image, (IMAGE_RES, IMAGE_RES))/255.0\n",
        "  return image, label\n",
        "\n",
        "BATCH_SIZE = 32\n",
        "\n",
        "train_batches      = train_examples.shuffle(num_examples//4).map(format_image).batch(BATCH_SIZE).prefetch(1)\n",
        "validation_batches = validation_examples.map(format_image).batch(BATCH_SIZE).prefetch(1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0gTN7M_GxDLx"
      },
      "source": [
        "## Run the classifier on a batch of images"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O3fvrZR8xDLv"
      },
      "source": [
        "Remember our `model` object is still the full MobileNet model trained on ImageNet, so it has 1000 possible output classes.\n",
        "ImageNet has a lot of dogs and cats in it, so let's see if it can predict the images in our Dogs vs. Cats dataset.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kii_jWZYOn0B"
      },
      "outputs": [],
      "source": [
        "image_batch, label_batch = next(iter(train_batches.take(1)))\n",
        "image_batch = image_batch.numpy()\n",
        "label_batch = label_batch.numpy()\n",
        "\n",
        "result_batch = model.predict(image_batch)\n",
        "\n",
        "predicted_class_names = imagenet_labels[np.argmax(result_batch, axis=-1)]\n",
        "predicted_class_names"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QmvSWg9nxDLa"
      },
      "source": [
        "The labels look like names of dog and cat breeds. Let's now plot the images from our Dogs vs. Cats dataset and put the ImageNet labels next to them."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IXTB22SpxDLP"
      },
      "outputs": [],
      "source": [
        "plt.figure(figsize=(10,9))\n",
        "for n in range(30):\n",
        "  plt.subplot(6,5,n+1)\n",
        "  plt.subplots_adjust(hspace = 0.3)\n",
        "  plt.imshow(image_batch[n])\n",
        "  plt.title(predicted_class_names[n])\n",
        "  plt.axis('off')\n",
        "_ = plt.suptitle(\"ImageNet predictions\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JzV457OXreQP"
      },
      "source": [
        "# Part 3: Do simple transfer learning with TensorFlow Hub\n",
        "\n",
        "Let's now use TensorFlow Hub to do Transfer Learning.\n",
        "\n",
        "With transfer learning we reuse parts of an already trained model and change the final layer, or several layers, of the model, and then retrain those layers on our own dataset.\n",
        "\n",
        "In addition to complete models, TensorFlow Hub also distributes models without the last classification layer. These can be used to easily do transfer learning. We will continue using MobileNet v2 because in later parts of this course, we will take this model and deploy it on a mobile device using [TensorFlow Lite](https://www.tensorflow.org/lite). Any [image feature vector URL from tfhub.dev](https://tfhub.dev/s?module-type=image-feature-vector\u0026q=tf2) would work here.\n",
        "\n",
        "We'll also continue to use the Dogs vs. Cats dataset, so we will be able to compare the performance of this model against the ones we created from scratch earlier.\n",
        "\n",
        "Note that we're calling the partial model from TensorFlow Hub (without the final classification layer) a `feature_extractor`. The reasoning for this term is that it will take the input all the way to a layer containing a number of features. So it has done the bulk of the work in identifying the content of an image, except for creating the final probability distribution. That is, it has extracted the features of the image."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5wB030nezBwI"
      },
      "outputs": [],
      "source": [
        "URL = \"https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/2\"\n",
        "feature_extractor = hub.KerasLayer(URL,\n",
        "                                   input_shape=(IMAGE_RES, IMAGE_RES,3))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pkSvAPvKOWg2"
      },
      "source": [
        "Let's run a batch of images through this, and see the final shape. 32 is the number of images, and 1280 is the number of neurons in the last layer of the partial model from TensorFlow Hub."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Of7i-35F09ls"
      },
      "outputs": [],
      "source": [
        "feature_batch = feature_extractor(image_batch)\n",
        "print(feature_batch.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CtFmF7A5E4tk"
      },
      "source": [
        "Freeze the variables in the feature extractor layer, so that the training only modifies the final classifier layer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Jg5ar6rcE4H-"
      },
      "outputs": [],
      "source": [
        "feature_extractor.trainable = False"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RPVeouTksO9q"
      },
      "source": [
        "## Attach a classification head\n",
        "\n",
        "Now wrap the hub layer in a `tf.keras.Sequential` model, and add a new classification layer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mGcY27fY1q3Q"
      },
      "outputs": [],
      "source": [
        "model = tf.keras.Sequential([\n",
        "  feature_extractor,\n",
        "  layers.Dense(2)\n",
        "])\n",
        "\n",
        "model.summary()"
      ]
    },
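    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sketch of what the new classification layer computes, the snippet below applies a dense layer without activation to a batch of feature vectors. The random `features`, `weights`, and `bias` are assumptions standing in for real MobileNet features and trained parameters.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "\n",
        "# Hypothetical stand-in: a batch of 32 feature vectors of length 1280.\n",
        "features = rng.normal(size=(32, 1280))\n",
        "\n",
        "# A dense layer with 2 units holds a weight matrix and a bias vector.\n",
        "weights = rng.normal(scale=0.01, size=(1280, 2))\n",
        "bias = np.zeros(2)\n",
        "\n",
        "# Forward pass of the head: one logit per class for every image.\n",
        "logits = features @ weights + bias\n",
        "\n",
        "print(logits.shape)              # (32, 2)\n",
        "print(weights.size + bias.size)  # 2562\n",
        "```\n",
        "\n",
        "With the feature extractor frozen, these 2,562 values (1280 x 2 weights plus 2 biases) are the only parameters that training updates."
      ]
    },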
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OHbXQqIquFxQ"
      },
      "source": [
        "## Train the model\n",
        "\n",
        "We now train this model like any other, by first calling `compile` followed by `fit`."
      ]
    },
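    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The loss used below works on integer labels and raw logits. As a minimal NumPy sketch of that computation (the function name and the example values are assumptions for illustration, not the Keras implementation):\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def sparse_categorical_crossentropy(labels, logits):\n",
        "  # Log-softmax computed in a numerically stable way.\n",
        "  shifted = logits - np.max(logits, axis=-1, keepdims=True)\n",
        "  log_probs = shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))\n",
        "  # Pick the log-probability of the true class for each example.\n",
        "  true_log_probs = log_probs[np.arange(len(labels)), labels]\n",
        "  # Average the negative log-likelihood over the batch.\n",
        "  return -np.mean(true_log_probs)\n",
        "\n",
        "# Hypothetical batch: 3 examples, 2 classes, integer class labels.\n",
        "labels = np.array([0, 1, 1])\n",
        "logits = np.array([[2.0, -1.0], [0.5, 1.5], [3.0, 0.0]])\n",
        "\n",
        "loss = sparse_categorical_crossentropy(labels, logits)\n",
        "print(loss)\n",
        "```\n",
        "\n",
        "The third example is confidently wrong (a large logit for class 0 while the label is 1), so it contributes most of the loss."
      ]
    },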
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3n0Wb9ylKd8R"
      },
      "outputs": [],
      "source": [
        "model.compile(\n",
        "  optimizer='adam',\n",
        "  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "  metrics=['accuracy'])\n",
        "\n",
        "EPOCHS = 6\n",
        "history = model.fit(train_batches,\n",
        "                    epochs=EPOCHS,\n",
        "                    validation_data=validation_batches)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "76as-K8-vFQJ"
      },
      "source": [
        "You can see we get ~97% validation accuracy, which is absolutely awesome. This is a huge improvement over the model we created in the previous lesson, where we were able to get ~83% accuracy. The reason for this difference is that MobileNet was carefully designed over a long time by experts, then trained on a massive dataset (ImageNet).\n",
        "\n",
        "Although not equivalent to TensorFlow Hub, you can check out how to create MobileNet in Keras [here](https://github.com/keras-team/keras-applications/blob/master/keras_applications/mobilenet.py).\n",
        "\n",
        "Let's plot the training and validation accuracy/loss graphs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "d28dhbFpr98b"
      },
      "outputs": [],
      "source": [
        "acc = history.history['accuracy']\n",
        "val_acc = history.history['val_accuracy']\n",
        "\n",
        "loss = history.history['loss']\n",
        "val_loss = history.history['val_loss']\n",
        "\n",
        "epochs_range = range(EPOCHS)\n",
        "\n",
        "plt.figure(figsize=(8, 8))\n",
        "plt.subplot(1, 2, 1)\n",
        "plt.plot(epochs_range, acc, label='Training Accuracy')\n",
        "plt.plot(epochs_range, val_acc, label='Validation Accuracy')\n",
        "plt.legend(loc='lower right')\n",
        "plt.title('Training and Validation Accuracy')\n",
        "\n",
        "plt.subplot(1, 2, 2)\n",
        "plt.plot(epochs_range, loss, label='Training Loss')\n",
        "plt.plot(epochs_range, val_loss, label='Validation Loss')\n",
        "plt.legend(loc='upper right')\n",
        "plt.title('Training and Validation Loss')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5zmoDisGvNye"
      },
      "source": [
        "What is a bit curious here is that validation performance is better than training performance, from the start to the end of training.\n",
        "\n",
        "One reason for this is that validation performance is measured at the end of the epoch, whereas training performance is averaged across the epoch.\n",
        "\n",
        "The bigger reason, though, is that we're reusing a large part of MobileNet, which has already been trained on images of dogs and cats. Only the small new classification layer has to learn, and it improves quickly within the first epoch, so the epoch-averaged training numbers can lag behind the validation numbers measured afterwards. Note that this Colab applies the same resizing and rescaling to both splits and no image augmentation. In a pipeline that augments only the training images, those images would also be harder to classify than the unmodified images in the validation dataset."
      ]
    },
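    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The effect of averaging across the epoch can be illustrated with a few made-up numbers. In this sketch the per-batch accuracies are hypothetical and merely mimic a model that improves while the epoch runs.\n",
        "\n",
        "```python\n",
        "# Hypothetical per-batch training accuracies within one epoch.\n",
        "batch_accuracies = [0.60, 0.80, 0.90, 0.95, 0.97]\n",
        "\n",
        "# Training metric for the epoch: the average over all of its batches.\n",
        "epoch_average = sum(batch_accuracies) / len(batch_accuracies)\n",
        "\n",
        "# Accuracy of the model as it stands once the epoch has finished.\n",
        "end_of_epoch = batch_accuracies[-1]\n",
        "\n",
        "print(round(epoch_average, 3))  # 0.844\n",
        "print(end_of_epoch)             # 0.97\n",
        "```\n",
        "\n",
        "The average is pulled down by the early batches, while a measurement taken after the epoch only sees the improved model."
      ]
    },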
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kb__ZN8uFn-D"
      },
      "source": [
        "## Check the predictions\n",
        "\n",
        "To redo the plot from before, first get the ordered list of class names."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W_Zvg2i0fzJu"
      },
      "outputs": [],
      "source": [
        "class_names = np.array(info.features['label'].names)\n",
        "class_names"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4Olg6MsNGJTL"
      },
      "source": [
        "Run the image batch through the model and convert the indices to class names."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fCLVCpEjJ_VP"
      },
      "outputs": [],
      "source": [
        "predicted_batch = model.predict(image_batch)\n",
        "predicted_batch = tf.squeeze(predicted_batch).numpy()\n",
        "predicted_ids = np.argmax(predicted_batch, axis=-1)\n",
        "predicted_class_names = class_names[predicted_ids]\n",
        "predicted_class_names"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CkGbZxl9GZs-"
      },
      "source": [
        "Let's look at the true labels and predicted ones."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nL9IhOmGI5dJ"
      },
      "outputs": [],
      "source": [
        "print(\"Labels: \", label_batch)\n",
        "print(\"Predicted labels: \", predicted_ids)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wC_AYRJU9NQe"
      },
      "outputs": [],
      "source": [
        "plt.figure(figsize=(10,9))\n",
        "for n in range(30):\n",
        "  plt.subplot(6,5,n+1)\n",
        "  plt.subplots_adjust(hspace = 0.3)\n",
        "  plt.imshow(image_batch[n])\n",
        "  color = \"blue\" if predicted_ids[n] == label_batch[n] else \"red\"\n",
        "  plt.title(predicted_class_names[n].title(), color=color)\n",
        "  plt.axis('off')\n",
        "_ = plt.suptitle(\"Model predictions (blue: correct, red: incorrect)\")"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "l06c01_tensorflow_hub_and_transfer_learning.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
